{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Use GLM-4 API to fine-tune model with LORA method\n",
    "\n",
    "**This tutorial is available in English and is attached below the Chinese explanation**\n",
    "\n",
    "此代码将展示如何调用 GLM-4 API 来使用LORA的方法对一个 glm 模型进行微调。你将会在本代码中学到：\n",
    "1. 微调数据集的格式\n",
    "2. 如何上传微调的数据集\n",
    "3. 如何管理和查看自己的微调数据集\n",
    "4. 如何设置超参数并开始微调模型\n",
    "5. 如何调用微调后的模型\n",
    "\n",
    "This code will show how to call the GLM-4 API to fine-tune a glm model using LORA methods. What you will learn in this code:\n",
    "1. Fine-tune the format of the data set\n",
    "2. How to upload fine-tuned dataset\n",
    "3. How to manage and view your own fine-tuning data set\n",
    "4. How to set hyperparameters and start fine-tuning the model\n",
    "5. How to call the fine-tuned model"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "62da0bf568d4e89f"
  },
  {
   "cell_type": "markdown",
   "source": [
    "首先，我们设置好 API KEY，并将这个API KEY设置为环境变量，这样我们就可以在代码中直接调用了。\n",
    "\n",
    "First, we set up the API KEY and set this API KEY as an environment variable so that we can call it directly in the code."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "ceb4590d0e68f177"
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [],
   "source": [
    "import os\n",
    "from zhipuai import ZhipuAI\n",
    "\n",
    "os.environ[\"ZHIPUAI_API_KEY\"] = \"your api key\"\n",
    "client = ZhipuAI()"
   ],
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2024-01-26T06:01:51.691702Z",
     "start_time": "2024-01-26T06:01:51.474710Z"
    }
   },
   "id": "initial_id"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Prepare the dataset\n",
    "\n",
    "需要准备训练的数据集，并将其拆分为训练集和验证集。我是用 [AdvertiseGen 数据集](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view) 并将其对其和自动拆分成了训练集和验证集。\n",
    "我们先下载数据集并将其放到 `data` 文件夹下并解压缩，接着运行如下脚本来转换数据集格式。\n",
    "由于 Zhipu AI的数据集格式必须按照[官方文档](https://open.bigmodel.cn/dev/howuse/finetuning)的格式进行编排,因此，整理完后的数据集格式如下：\n",
    "\n",
    "The training data set needs to be prepared and split into a training set and a validation set. I used the [AdvertiseGen data set](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view) and automatically split it into a training set and a validation set.\n",
    "We first download the data set and place it in the `data` folder and unzip it, then run the following script to convert the data set format.\n",
    "Since the data set format of Zhipu AI must be arranged according to the [official document](https://open.bigmodel.cn/dev/howuse/finetuning), the format of the completed data set is as follows:\n",
    "\n",
    "```\n",
    "{\"messages\": [{\"role\": \"user\", \"content\": \"类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤\"}, {\"role\": \"assistant\", \"content\": \"宽松的阔腿裤这两年真的吸粉不少，明星时尚达人的心头爱。毕竟好穿时尚，谁都能穿出腿长2米的效果宽松的裤腿，当然是遮肉小能手啊。上身随性自然不拘束，面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点，还让单品的设计感更强。腿部线条若隐若现的，性感撩人。颜色敲温柔的，与裤子本身所呈现的风格有点反差萌。\"}]}\n",
    "{\"messages\": [{\"role\": \"user\", \"content\": \"类型#裙*风格#简约*图案#条纹*图案#线条*图案#撞色*裙型#鱼尾裙*裙袖长#无袖\"}, {\"role\": \"assistant\", \"content\": \"圆形领口修饰脖颈线条，适合各种脸型，耐看有气质。无袖设计，尤显清凉，简约横条纹装饰，使得整身人鱼造型更为生动立体。加之撞色的鱼尾下摆，深邃富有诗意。收腰包臀,修饰女性身体曲线，结合别出心裁的鱼尾裙摆设计，勾勒出自然流畅的身体轮廓，展现了婀娜多姿的迷人姿态。\"}]}\n",
    "\n",
    "```"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "9915962fb3f28662"
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "from typing import Union\n",
    "from pathlib import Path\n",
    "import json\n",
    "\n",
    "\n",
    "def advertise_gen_convert(data_dir: Union[str, Path], save_dir: Union[str, Path]):\n",
    "    def _convert(in_file: Path, out_file: Path):\n",
    "        with open(in_file, encoding='utf-8') as fin, open(out_file, 'w', encoding='utf-8') as fout:\n",
    "            for line in fin:\n",
    "                item = json.loads(line)\n",
    "                sample = {'messages': [{'role': 'user', 'content': item['content']},\n",
    "                                       {'role': 'assistant', 'content': item['summary']}]}\n",
    "                fout.write(json.dumps(sample, ensure_ascii=False) + '\\n')\n",
    "\n",
    "    data_dir = Path(data_dir).resolve()\n",
    "    save_dir = Path(save_dir).resolve()\n",
    "\n",
    "    for file_name in ['train.json', 'dev.json']:\n",
    "        in_file = data_dir / file_name\n",
    "        out_file = save_dir / f'{in_file.stem}.jsonl'\n",
    "\n",
    "        if in_file.is_file():\n",
    "            out_file.parent.mkdir(parents=True, exist_ok=True)\n",
    "            _convert(in_file, out_file)\n",
    "\n",
    "\n",
    "advertise_gen_convert(data_dir=\"data/AdvertiseGen\", save_dir=\"data/AdvertiseGen_Formatted\")"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-26T06:01:59.301189Z",
     "start_time": "2024-01-26T06:01:59.252924Z"
    }
   },
   "id": "b043656cf5cc4721"
  },
  {
   "cell_type": "markdown",
   "source": [
    "将数据集转换结束后，通过SDK提供的接口上传数据集到Zhipu AI的服务器上，并查看上传的数据集的文件列表。\n",
    "\n",
    "After the data set is converted, upload the data set to Zhipu AI's server through the interface provided by the SDK, and view the file list of the uploaded data set."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "b74d4cf6b5455bc5"
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "data": {
      "text/plain": "ListOfFileObject(object='list', data=[FileObject(id='file-20240126055309882-wnpcq', bytes=60066225, created_at=1706248392000, filename='train.jsonl', object='fd', purpose='fine-tune', status=None, status_details=None), FileObject(id='file-20240126055312192-6zlhw', bytes=557244, created_at=1706248392000, filename='dev.jsonl', object='fd', purpose='fine-tune', status=None, status_details=None), FileObject(id='file-20240126032709833-5zrkr', bytes=557244, created_at=1706239630000, filename='dev.jsonl', object='fd', purpose='fine-tune', status=None, status_details=None), FileObject(id='file-20240126032453929-hm9mh', bytes=60066225, created_at=1706239496000, filename='train.jsonl', object='fd', purpose='fine-tune', status=None, status_details=None)], has_more=None)"
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "client.files.create(\n",
    "    file=open(\"data/AdvertiseGen_Formatted/train.jsonl\", \"rb\"),\n",
    "    purpose=\"fine-tune\"\n",
    ")\n",
    "\n",
    "client.files.create(\n",
    "    file=open(\"data/AdvertiseGen_Formatted/dev.jsonl\", \"rb\"),\n",
    "    purpose=\"fine-tune\"\n",
    ")\n",
    "client.files.list()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-26T06:02:01.264253Z",
     "start_time": "2024-01-26T06:02:01.132029Z"
    }
   },
   "id": "4e2f3421a6f6953f"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Start training using the uploaded file\n",
    "\n",
    "现在，在列表中已经找到了训练集和验证集的两个文件，我们设置好相关的参数，并上传到 Zhipu AI 的服务器上开始训练。按照[官方文档](https://open.bigmodel.cn/dev/api#finetuning)填写相关参数。我们就可以开训练了。我选择使用最小的模型 `chatglm3-6b` 进行训练。\n",
    "\n",
    "<span style='color: red; font-weight: bold;'>⚠️ 当前 ZhipuAI API 还不支持 更大参数的模型训练，训练方法也仅支持 LORA ⚠️</span>\n",
    "\n",
    "当前版本SDK中，默认使用的是 LORA 的方式进行训练。由于超参数在文档中仅开放了三个参数，分别是\n",
    "- `learning_rate_multiplier` 学习率的倍数\n",
    "- `batch_size` 批次大小\n",
    "- `n_epochs` 训练轮数\n",
    "\n",
    "接着，我们就能开始训练了，并且，我们打印了训练的`job_id`，以便后续查看训练进度。\n",
    "\n",
    "\n",
    "Now that we have found the two files of the training set and validation set in the list, we set the relevant parameters and uploaded them to the Zhipu AI server to start training. Fill in the relevant parameters according to [official documentation](https://open.bigmodel.cn/dev/api#finetuning). We can start training. I chose to use the smallest model `chatglm3-6b` for training.\n",
    "\n",
    "<span style='color: red; font-weight: bold;'>⚠️ Currently ZhipuAI API does not support model training with larger parameters, and the training method only supports LORA ⚠️</span>\n",
    "\n",
    "In the current version of the SDK, the LORA method is used for training by default. Since the hyperparameters only open three parameters in the document, they are\n",
    "- `learning_rate_multiplier` multiple of the learning rate\n",
    "- `batch_size` batch size\n",
    "- `n_epochs` Number of training epochs\n",
    "These values have default values in the documentation, but I chose to set them myself.\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "27c6eec7eead67f8"
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "# set more hyperparameters\n",
    "# job = client.fine_tuning.jobs.create(\n",
    "#     model=\"chatglm3-6b\",\n",
    "#     training_file=\"file-20240126055312192-6zlhw\",\n",
    "#     validation_file=\"file-20240126055312192-6zlhw\",\n",
    "#     hyperparameters={\n",
    "#         \"learning_rate_multiplier\": 0.5,\n",
    "#         \"batch_size\": 32,\n",
    "#         \"n_epochs\": 2,\n",
    "#     },\n",
    "#     suffix=\"advertise_gen\",\n",
    "# )\n",
    "\n",
    "# using default hyperparameters\n",
    "job = client.fine_tuning.jobs.create(\n",
    "    model=\"chatglm3-6b\",\n",
    "    training_file=\"file-20240126055312192-6zlhw\",\n",
    "    validation_file=\"file-20240126055312192-6zlhw\",\n",
    ")\n",
    "job_id = job.id"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-26T06:02:04.085283Z",
     "start_time": "2024-01-26T06:02:04.076408Z"
    }
   },
   "id": "ca08f2b278867ce0"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 查看训练进度\n",
    "我们可以查看这个训练任务的进度，以及训练的状态。 由于训练任务是异步的，因此，我们需要等待一段时间后才能看到训练的状态。其中\n",
    "- `status` 训练状态，现在是 `running`，表示正在训练\n",
    "\n",
    "由于本数据集的训练机数量较大（超过11万条数据），训练时间较长（在书写本cookbook时微调时间超过了1个小时), 因此，我在实际微调的过程使用，使用的是 `dev.jsonl` 文件进行微调和验证\n",
    "您可以自己删除部分数据集，这样能有效降低训练时间和训练的token消耗成本。\n",
    "\n",
    "We can check the progress of this training task and the status of training. Since the training task is asynchronous, we need to wait for a period of time before we can see the status of the training. in\n",
    "- `status` training status, now `running`, indicating training is in progress\n",
    "\n",
    "Since the number of training machines in this data set is large (more than 110,000 pieces of data) and the training time is long (the fine-tuning time took more than 1 hour when writing this cookbook), therefore, in the actual fine-tuning process, I used `dev.jsonl` file for fine-tuning and verification\n",
    "You can delete part of the data set yourself, which can effectively reduce training time and training token consumption costs."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "671dddd814cdb55d"
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "id='ftjob-20240126135812634-4m65v' request_id=None created_at=None error=None fine_tuned_model='ft:chatglm3-6b:advertise_gen:sg59pfxk' finished_at=None hyperparameters=None model='chatglm3-6b' object='fine_tuning.job' result_files=[] status='running' trained_tokens=None training_file='file-20240126055312192-6zlhw' validation_file='file-20240126055312192-6zlhw'\n"
     ]
    }
   ],
   "source": [
    "fine_tuning_job = client.fine_tuning.jobs.retrieve(fine_tuning_job_id=job_id)"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-26T06:02:39.797873Z",
     "start_time": "2024-01-26T06:02:39.646808Z"
    }
   },
   "id": "bd40ce03fe32fd6e"
  },
  {
   "cell_type": "markdown",
   "source": [
    "通过等待一段时间我们再次查询，我们发现，模型已经训练完成，接下来，我们就是调用模型进行测试了。\n",
    "通过 Python SDK来查询我所有的微调任务进度，其中，这些字段的含义如下：\n",
    "- `id` 微调任务的id\n",
    "- `fine_tuned_model` 微调模型的名称\n",
    "- `request_id` 用户请求的id，默认是空\n",
    "- `created_at`,`finished_at` 任务的开始时间和结束时间,在这里都是空\n",
    "- `status` 任务的状态，现在是 `running`，表示正在训练\n",
    "\n",
    "By waiting for a while and querying again, we find that the model has been trained. Next, we call the model for testing.\n",
    "Use the Python SDK to check the progress of all my fine-tuning tasks."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "ffbe2fd50dbddc9e"
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "outputs": [
    {
     "data": {
      "text/plain": "ListOfFineTuningJob(object='list', data=[FineTuningJob(id='ftjob-20240126135812634-4m65v', request_id=None, created_at=None, error=None, fine_tuned_model='ft:chatglm3-6b:advertise_gen:sg59pfxk', finished_at=None, hyperparameters=None, model='chatglm3-6b', object='fine_tuning.job', result_files=['file-20240126060112746-zqkss'], status='succeeded', trained_tokens=None, training_file='file-20240126055312192-6zlhw', validation_file='file-20240126055312192-6zlhw'), FineTuningJob(id='ftjob-20240126113041678-cs6s9', request_id=None, created_at=None, error=None, fine_tuned_model='ft:chatglm3-6b:advertise_gen:dkgf9mc8', finished_at=None, hyperparameters=None, model='chatglm3-6b', object='fine_tuning.job', result_files=[], status='failed', trained_tokens=None, training_file='file-20240126032453929-hm9mh', validation_file='file-20240126032709833-5zrkr')], has_more=None)"
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "client.fine_tuning.jobs.list()"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-26T06:18:09.407830Z",
     "start_time": "2024-01-26T06:18:09.264230Z"
    }
   },
   "id": "8704ce3273581614"
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 调用微调后的模型\n",
    "\n",
    "在完成微调任务之后，我们会得到一个新的模型，我们可以通过SDK来完成对这个模型的调用。我们首先查看一下模型列表，然后选择最新的模型进行调用。使用微调后的模型的请求方式与常规的 GLM 模型没有区别，仅需将 `model` 更换为刚才微调的模型名称即可。\n",
    "\n",
    "After completing the fine-tuning task, we will get a new model, and we can complete the call to this model through the SDK. We first check the model list, and then select the latest model to call. The request method for using the fine-tuned model is no different from the regular GLM model. You only need to replace `model` with the name of the model you just fine-tuned.\n"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "e5dcb2f0413c8590"
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "outputs": [
    {
     "data": {
      "text/plain": "'这件裤子采用了简约的英伦风格设计，展现出了古典与现代的完美结合。它采用了合身的剪裁，让您的身形更加挺拔优雅。裤子的面料使用了高品质的棉质面料，让您在穿着时更加舒适自在。此外，这条裤子还采用了细节处理，如精细的缝线和纽扣，展现了品质卓越的特点。整体来看，这条裤子既简约又具有时尚感，是您英伦风格搭配的理想选择。'"
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response = client.chat.completions.create(\n",
    "    model=\"ft:chatglm3-6b:advertise_gen:sg59pfxk\",\n",
    "    messages=[\n",
    "        {\"role\": \"user\", \"content\": \"类型#裤*风格#英伦*风格#简约\"},\n",
    "    ],\n",
    ")\n",
    "response.choices[0].message.content"
   ],
   "metadata": {
    "collapsed": false,
    "ExecuteTime": {
     "end_time": "2024-01-26T06:18:39.760763Z",
     "start_time": "2024-01-26T06:18:37.865028Z"
    }
   },
   "id": "71656127b9ae9da8"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
